{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xqLjB2cy5S7m"
   },
   "source": [
    "## MNIST in Keras with TensorBoard\n",
    "\n",
    "This sample trains an \"MNIST\" handwritten digit\n",
    "recognition model, written in Keras, on a GPU or TPU backend.\n",
    "Data is handled using the tf.data.Dataset API. This is\n",
    "a very simple sample provided for educational purposes. Do\n",
    "not expect outstanding TPU performance on a dataset as\n",
    "small as MNIST."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lvo0t7XVIkWZ"
   },
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cCpkS9C_H7Tl"
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 0.002\n",
    "# GCS bucket for training logs and for saving the trained model\n",
    "# You can leave this empty for local saving, unless you are using a TPU.\n",
    "# TPUs do not have access to your local instance and can only write to GCS.\n",
    "BUCKET=\"\" # a valid bucket name must start with gs://\n",
    "\n",
    "training_images_file   = 'gs://mnist-public/train-images-idx3-ubyte'\n",
    "training_labels_file   = 'gs://mnist-public/train-labels-idx1-ubyte'\n",
    "validation_images_file = 'gs://mnist-public/t10k-images-idx3-ubyte'\n",
    "validation_labels_file = 'gs://mnist-public/t10k-labels-idx1-ubyte'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qpiJj8ym0v0-"
   },
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "AoilhmYe1b5t",
    "outputId": "914f12e4-ca4e-4b92-ddf5-acdf57c0b13b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensorflow version 2.2.0-dlenv\n"
     ]
    }
   ],
   "source": [
    "import os, re, math, json, time\n",
    "import PIL.Image, PIL.ImageFont, PIL.ImageDraw\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from matplotlib import pyplot as plt\n",
    "from tensorflow.python.platform import tf_logging\n",
    "print(\"Tensorflow version \" + tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TPU/GPU detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Initializing the TPU system: martin-tpuv3-8-tf22\n",
      "INFO:tensorflow:Clearing out eager caches\n",
      "INFO:tensorflow:Finished initializing TPU system.\n",
      "INFO:tensorflow:Found TPU system:\n",
      "INFO:tensorflow:*** Num TPU Cores: 8\n",
      "INFO:tensorflow:*** Num TPU Workers: 1\n",
      "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n",
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n",
      "Number of accelerators:  8\n"
     ]
    }
   ],
   "source": [
    "try: # detect TPUs\n",
    "  tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n",
    "  tf.config.experimental_connect_to_cluster(tpu)\n",
    "  tf.tpu.experimental.initialize_tpu_system(tpu)\n",
    "  strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
    "except ValueError: # detect GPUs\n",
    "  strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines\n",
    "  #strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU\n",
    "  #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines\n",
    "\n",
    "print(\"Number of accelerators: \", strategy.num_replicas_in_sync)\n",
    "    \n",
    "# adjust batch size and learning rate for distributed computing\n",
    "global_batch_size = BATCH_SIZE * strategy.num_replicas_in_sync # num_replicas_in_sync is 8 on a single TPU, or N when running on N GPUs\n",
    "learning_rate = LEARNING_RATE * strategy.num_replicas_in_sync"
   ]
  },
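  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch size and learning rate scaling in the cell above can be sketched in plain Python, without TensorFlow. This is a minimal illustration only: `scale_for_replicas` is a hypothetical helper that is not part of this notebook or of TensorFlow, and multiplying the learning rate by the replica count is the simple linear heuristic used above, not a tuned schedule.\n",
    "\n",
    "```python\n",
    "def scale_for_replicas(per_replica_batch_size, base_learning_rate, num_replicas):\n",
    "    # Each replica processes per_replica_batch_size items per step,\n",
    "    # so the global batch grows linearly with the number of replicas.\n",
    "    global_batch_size = per_replica_batch_size * num_replicas\n",
    "    # Linear scaling heuristic: grow the learning rate with the global batch.\n",
    "    learning_rate = base_learning_rate * num_replicas\n",
    "    return global_batch_size, learning_rate\n",
    "\n",
    "# Values from the Parameters cell, with the 8 cores of a single TPU.\n",
    "global_batch_size, learning_rate = scale_for_replicas(64, 0.002, 8)\n",
    "steps_per_epoch = 60000 // global_batch_size  # 60,000 MNIST training images\n",
    "print(global_batch_size, learning_rate, steps_per_epoch)\n",
    "```\n",
    "\n",
    "With 8 replicas this gives a global batch of 512 items and 117 steps per epoch, so a larger accelerator count means fewer, larger steps for the same amount of data."
   ]
  },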
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "qhdz68Xm3Z4Z"
   },
   "outputs": [],
   "source": [
    "#@title visualization utilities [RUN ME]\n",
    "\"\"\"\n",
    "This cell contains helper functions used for visualization\n",
    "and downloads only. You can skip reading it. There is very\n",
    "little useful Keras/Tensorflow code here.\n",
    "\"\"\"\n",
    "\n",
    "# Matplotlib config\n",
    "plt.rc('image', cmap='gray_r')\n",
    "plt.rc('grid', linewidth=0)\n",
    "plt.rc('xtick', top=False, bottom=False, labelsize='large')\n",
    "plt.rc('ytick', left=False, right=False, labelsize='large')\n",
    "plt.rc('axes', facecolor='F8F8F8', titlesize=\"large\", edgecolor='white')\n",
    "plt.rc('text', color='a8151a')\n",
    "plt.rc('figure', facecolor='F0F0F0')\n",
    "# Matplotlib fonts\n",
    "MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), \"mpl-data/fonts/ttf\")\n",
    "\n",
    "# pull one batch from each dataset and convert it to numpy arrays (uses eager-mode iteration)\n",
    "def dataset_to_numpy_util(training_dataset, validation_dataset, N):\n",
    "  \n",
    "  # get one batch from each: 10000 validation digits, N training digits\n",
    "  unbatched_train_ds = training_dataset.unbatch()\n",
    "  \n",
    "\n",
    "  # This is the TF 2.0 \"eager execution\" way of iterating through a tf.data.Dataset\n",
    "  for v_images, v_labels in validation_dataset:\n",
    "    break\n",
    "\n",
    "  for t_images, t_labels in unbatched_train_ds.batch(N):\n",
    "    break\n",
    "\n",
    "  validation_digits = v_images.numpy()\n",
    "  validation_labels = v_labels.numpy()\n",
    "  training_digits   = t_images.numpy()\n",
    "  training_labels   = t_labels.numpy()\n",
    "\n",
    "  \n",
    "  # these were one-hot encoded in the dataset\n",
    "  validation_labels = np.argmax(validation_labels, axis=1)\n",
    "  training_labels = np.argmax(training_labels, axis=1)\n",
    "  \n",
    "  return (training_digits, training_labels,\n",
    "          validation_digits, validation_labels)\n",
    "\n",
    "# create digits from local fonts for testing\n",
    "def create_digits_from_local_fonts(n):\n",
    "  font_labels = []\n",
    "  img = PIL.Image.new('LA', (28*n, 28), color = (0,255)) # format 'LA': black in channel 0, alpha in channel 1\n",
    "  font1 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'DejaVuSansMono-Oblique.ttf'), 25)\n",
    "  font2 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'STIXGeneral.ttf'), 25)\n",
    "  d = PIL.ImageDraw.Draw(img)\n",
    "  for i in range(n):\n",
    "    font_labels.append(i%10)\n",
    "    d.text((7+i*28,0 if i<10 else -4), str(i%10), fill=(255,255), font=font1 if i<10 else font2)\n",
    "  font_digits = np.array(img.getdata(), np.float32)[:,0] / 255.0 # black in channel 0, alpha in channel 1 (discarded)\n",
    "  font_digits = np.reshape(np.stack(np.split(np.reshape(font_digits, [28, 28*n]), n, axis=1), axis=0), [n, 28*28])\n",
    "  return font_digits, font_labels\n",
    "\n",
    "# utility to display a row of digits with their predictions\n",
    "def display_digits(digits, predictions, labels, title, n):\n",
    "  plt.figure(figsize=(13,3))\n",
    "  digits = np.reshape(digits, [n, 28, 28])\n",
    "  digits = np.swapaxes(digits, 0, 1)\n",
    "  digits = np.reshape(digits, [28, 28*n])\n",
    "  plt.yticks([])\n",
    "  plt.xticks([28*x+14 for x in range(n)], predictions)\n",
    "  for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):\n",
    "    if predictions[i] != labels[i]: t.set_color('red') # bad predictions in red\n",
    "  plt.imshow(digits)\n",
    "  plt.grid(None)\n",
    "  plt.title(title)\n",
    "  \n",
    "# utility to display multiple rows of digits, sorted by unrecognized/recognized status\n",
    "def display_top_unrecognized(digits, predictions, labels, n, lines):\n",
    "  idx = np.argsort(predictions==labels) # sort order: unrecognized first\n",
    "  for i in range(lines):\n",
    "    display_digits(digits[idx][i*n:(i+1)*n], predictions[idx][i*n:(i+1)*n], labels[idx][i*n:(i+1)*n],\n",
    "                   \"{} sample validation digits out of {} with bad predictions in red and sorted first\".format(n*lines, len(digits)) if i==0 else \"\", n)"
   ]
  },
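  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshape, swapaxes, reshape sequence in `display_digits` above is what lays the digits out side by side in one wide image. The sketch below shows the same trick on tiny made-up 2x2 images so the result can be checked by eye. `tile_images_in_a_row` is a hypothetical name used only for this illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def tile_images_in_a_row(images, height, width):\n",
    "    # images: array of shape [n, height*width], one flattened image per row\n",
    "    n = images.shape[0]\n",
    "    tiled = np.reshape(images, [n, height, width])\n",
    "    # Move the pixel-row axis first so that each output row holds\n",
    "    # the same pixel row from every image, side by side.\n",
    "    tiled = np.swapaxes(tiled, 0, 1)               # shape [height, n, width]\n",
    "    return np.reshape(tiled, [height, n * width])  # shape [height, n*width]\n",
    "\n",
    "# Three tiny 2x2 images filled with 0, 1 and 2 so the layout is easy to check.\n",
    "images = np.stack([np.full(4, i, dtype=np.float32) for i in range(3)])\n",
    "row = tile_images_in_a_row(images, 2, 2)\n",
    "print(row)\n",
    "```\n",
    "\n",
    "Without the `swapaxes` step, the final reshape would interleave pixel rows from different images instead of placing whole images next to each other."
   ]
  },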
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lzd6Qi464PsA"
   },
   "source": [
    "### Colab-only auth for this notebook and the TPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "cellView": "both",
    "colab": {},
    "colab_type": "code",
    "id": "MPx0nvyUnvgT"
   },
   "outputs": [],
   "source": [
    "#IS_COLAB_BACKEND = 'COLAB_GPU' in os.environ  # this is always set on Colab, the value is 0 or 1 depending on GPU presence\n",
    "#if IS_COLAB_BACKEND:\n",
    "#  from google.colab import auth\n",
    "#  auth.authenticate_user() # Authenticates the backend and also the TPU using your credentials so that they can access your private GCS buckets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lz1Zknfk4qCx"
   },
   "source": [
    "### tf.data.Dataset: parse files and prepare training and validation datasets\n",
    "Please read the [best practices for building input pipelines](https://www.tensorflow.org/guide/performance/datasets) with tf.data.Dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZE8dgyPC1_6m"
   },
   "outputs": [],
   "source": [
    "def read_label(tf_bytestring):\n",
    "    label = tf.io.decode_raw(tf_bytestring, tf.uint8)\n",
    "    label = tf.reshape(label, [])\n",
    "    label = tf.one_hot(label, 10)\n",
    "    return label\n",
    "  \n",
    "def read_image(tf_bytestring):\n",
    "    image = tf.io.decode_raw(tf_bytestring, tf.uint8)\n",
    "    image = tf.cast(image, tf.float32)/256.0\n",
    "    image = tf.reshape(image, [28*28])\n",
    "    return image\n",
    "  \n",
    "def load_dataset(image_file, label_file):\n",
    "    imagedataset = tf.data.FixedLengthRecordDataset(image_file, 28*28, header_bytes=16)\n",
    "    imagedataset = imagedataset.map(read_image, num_parallel_calls=16)\n",
    "    labelsdataset = tf.data.FixedLengthRecordDataset(label_file, 1, header_bytes=8)\n",
    "    labelsdataset = labelsdataset.map(read_label, num_parallel_calls=16)\n",
    "    dataset = tf.data.Dataset.zip((imagedataset, labelsdataset))\n",
    "    return dataset\n",
    "  \n",
    "def get_training_dataset(image_file, label_file, batch_size):\n",
    "    dataset = load_dataset(image_file, label_file)\n",
    "    dataset = dataset.cache()  # this small dataset can be entirely cached in RAM, for TPU this is important to get good performance from such a small dataset\n",
    "    dataset = dataset.shuffle(5000, reshuffle_each_iteration=True)\n",
    "    dataset = dataset.repeat() # Mandatory for Keras for now\n",
    "    dataset = dataset.batch(batch_size, drop_remainder=True) # drop_remainder is important on TPU, batch size must be fixed\n",
    "    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)  # fetch next batches while training on the current one (AUTOTUNE, equal to -1, lets tf.data pick the buffer size)\n",
    "    return dataset\n",
    "  \n",
    "def get_validation_dataset(image_file, label_file):\n",
    "    dataset = load_dataset(image_file, label_file)\n",
    "    dataset = dataset.cache() # this small dataset can be entirely cached in RAM, for TPU this is important to get good performance from such a small dataset\n",
    "    dataset = dataset.repeat() # Mandatory for Keras for now\n",
    "    dataset = dataset.batch(10000, drop_remainder=True) # 10000 items in eval dataset, all in one batch\n",
    "    return dataset\n",
    "\n",
    "# instantiate the datasets\n",
    "training_dataset = get_training_dataset(training_images_file, training_labels_file, global_batch_size)\n",
    "validation_dataset = get_validation_dataset(validation_images_file, validation_labels_file)"
   ]
  },
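  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `header_bytes=16` and `header_bytes=8` arguments above skip the headers of the MNIST image and label files, after which every record has a fixed length. The sketch below mimics that parsing with NumPy on synthetic bytes. It is an illustration under assumptions: the helper names are hypothetical, the data is made up rather than real MNIST, and the header layout (magic number followed by big-endian counts) is taken from the usual IDX format description.\n",
    "\n",
    "```python\n",
    "import struct\n",
    "import numpy as np\n",
    "\n",
    "def parse_idx_images(raw, header_bytes=16, pixels=28 * 28):\n",
    "    # Skip the header, then read fixed-length records of one byte per pixel.\n",
    "    data = np.frombuffer(raw, dtype=np.uint8, offset=header_bytes)\n",
    "    return data.reshape(-1, pixels).astype(np.float32) / 256.0\n",
    "\n",
    "def parse_idx_labels(raw, header_bytes=8, num_classes=10):\n",
    "    # Skip the header, then read one byte per label and one-hot encode it.\n",
    "    labels = np.frombuffer(raw, dtype=np.uint8, offset=header_bytes)\n",
    "    return np.eye(num_classes, dtype=np.float32)[labels]\n",
    "\n",
    "# Synthetic stand-in data (an assumption, not real MNIST): two images, two labels.\n",
    "image_header = struct.pack('>IIII', 2051, 2, 28, 28)  # magic, count, rows, columns\n",
    "label_header = struct.pack('>II', 2049, 2)            # magic, count\n",
    "raw_images = image_header + bytes([0] * 784) + bytes([128] * 784)\n",
    "raw_labels = label_header + bytes([3, 7])\n",
    "\n",
    "images = parse_idx_images(raw_images)\n",
    "labels = parse_idx_labels(raw_labels)\n",
    "print(images.shape, labels.shape, labels.argmax(axis=1))\n",
    "```\n",
    "\n",
    "`tf.data.FixedLengthRecordDataset` does the same header skip and record split lazily, which lets the parsing run inside the input pipeline instead of loading the whole file first."
   ]
  },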
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_fXo6GuvL3EB"
   },
   "source": [
    "### Let's have a look at the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 177
    },
    "colab_type": "code",
    "id": "yZ4tjPKvL2eh",
    "outputId": "11c2414a-ab78-4716-b3e1-5942a90239c6"
   },
   "outputs": [],
   "source": [
    "N = 24\n",
    "(training_digits, training_labels,\n",
    " validation_digits, validation_labels) = dataset_to_numpy_util(training_dataset, validation_dataset, N)\n",
    "display_digits(training_digits, training_labels, training_labels, \"training digits and their labels\", N)\n",
    "display_digits(validation_digits[:N], validation_labels[:N], validation_labels[:N], \"validation digits and their labels\", N)\n",
    "font_digits, font_labels = create_digits_from_local_fonts(N)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KIc0oqiD40HC"
   },
   "source": [
    "### Keras model: 3 convolutional layers, 2 dense layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 672
    },
    "colab_type": "code",
    "id": "56y8UNFQIVwj",
    "outputId": "9881b784-7da6-406c-8756-e1b9b71ec5c7"
   },
   "outputs": [],
   "source": [
    "# This model trains to 99.4% (sometimes 99.5%) accuracy in 10 epochs (with a batch size of 64)\n",
    "\n",
    "def make_model():\n",
    "    \n",
    "    model = tf.keras.Sequential(\n",
    "      [\n",
    "        tf.keras.layers.Reshape(input_shape=(28*28,), target_shape=(28, 28, 1)),\n",
    "\n",
    "        tf.keras.layers.Conv2D(filters=6, kernel_size=3, padding='same', use_bias=True, activation='relu'),\n",
    "        tf.keras.layers.Conv2D(filters=12, kernel_size=6, padding='same', use_bias=True, activation='relu', strides=2),\n",
    "        tf.keras.layers.Conv2D(filters=24, kernel_size=6, padding='same', use_bias=True, activation='relu', strides=2),\n",
    "        tf.keras.layers.Flatten(),\n",
    "        tf.keras.layers.Dense(200, use_bias=True, activation='relu'),\n",
    "        tf.keras.layers.Dense(10, activation='softmax')\n",
    "      ])\n",
    "\n",
    "    model.compile(optimizer='adam', # uses Adam's default learning rate; the scaled learning_rate computed earlier is not applied unless passed via tf.keras.optimizers.Adam(learning_rate)\n",
    "                  loss='categorical_crossentropy',\n",
    "                  metrics=['accuracy'])\n",
    "    return model\n",
    "    \n",
    "with strategy.scope(): # the new way of handling distribution strategies in Tensorflow 1.14+\n",
    "    model = make_model()\n",
    "\n",
    "# print model layers\n",
    "model.summary()\n",
    "\n",
    "# set up Tensorboard logs\n",
    "timestamp = time.strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
    "log_dir=os.path.join(BUCKET, 'mnist-logs', timestamp)\n",
    "tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq=50) # in TF 2.2 an integer update_freq is counted in batches, not samples\n",
    "print(\"TensorBoard logs written to: \", log_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CuhDh8ao8VyB"
   },
   "source": [
    "### Train and validate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1162
    },
    "colab_type": "code",
    "id": "TTwH_P-ZJ_xx",
    "outputId": "a5be8502-8a51-4f68-80c7-ec2d586d200b"
   },
   "outputs": [],
   "source": [
    "EPOCHS = 10\n",
    "steps_per_epoch = 60000//global_batch_size  # 60,000 items in this dataset\n",
    "print(\"Steps (batches) per epoch: \", steps_per_epoch)\n",
    "  \n",
    "history = model.fit(training_dataset, steps_per_epoch=steps_per_epoch, epochs=EPOCHS,\n",
    "                    validation_data=validation_dataset, validation_steps=1,\n",
    "                    callbacks=[tb_callback], verbose=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9jFVovcUUVs1"
   },
   "source": [
    "### Visualize predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 790
    },
    "colab_type": "code",
    "id": "w12OId8Mz7dF",
    "outputId": "6f5a05f0-dac5-4bea-a713-32e51f85fc79"
   },
   "outputs": [],
   "source": [
    "# recognize digits from local fonts\n",
    "probabilities = model.predict(font_digits, steps=1)\n",
    "predicted_labels = np.argmax(probabilities, axis=1)\n",
    "display_digits(font_digits, predicted_labels, font_labels, \"predictions from local fonts (bad predictions in red)\", N)\n",
    "\n",
    "# recognize validation digits\n",
    "probabilities = model.predict(validation_digits, steps=1)\n",
    "predicted_labels = np.argmax(probabilities, axis=1)\n",
    "display_top_unrecognized(validation_digits, predicted_labels, validation_labels, N, 7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SVY1pBg5ydH-"
   },
   "source": [
    "## License"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hleIN5-pcr0N"
   },
   "source": [
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "author: Martin Gorner<br>\n",
    "twitter: @martin_gorner\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "Copyright 2020 Google LLC\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "This is not an official Google product but sample code provided for educational purposes.\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "Keras MNIST TPU (public).ipynb",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "environment": {
   "name": "tf22-cpu.2-2.m47",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf22-cpu.2-2:m47"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
